{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"   # see issue #152\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "from __future__ import print_function\n",
    "from six.moves import xrange\n",
    "import tensorflow.contrib.slim as slim\n",
    "import os\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import tensorflow.contrib.layers as ly\n",
    "from load_svhn import load_svhn\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def lrelu(x, leak=0.3, name=\"lrelu\"):\n",
    "    \"\"\"Leaky ReLU expressed as a blend of the identity and the absolute value.\n",
    "\n",
    "    Evaluates ``0.5 * (1 + leak) * x + 0.5 * (1 - leak) * |x|``, which is ``x``\n",
    "    for non-negative inputs and ``leak * x`` for negative inputs.\n",
    "\n",
    "    Args:\n",
    "        x: value to activate; it must support ``*``, ``+`` and ``abs()``.\n",
    "        leak: slope applied to negative inputs.\n",
    "        name: name of the variable scope the ops are created under.\n",
    "\n",
    "    Returns:\n",
    "        The activated value.\n",
    "    \"\"\"\n",
    "    with tf.variable_scope(name):\n",
    "        identity_weight = (1 + leak) * 0.5\n",
    "        magnitude_weight = (1 - leak) * 0.5\n",
    "        return identity_weight * x + magnitude_weight * abs(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Number of samples per batch; used for the noise sample, the real-data\n",
    "# placeholder and every dataset.train.next_batch() call.\n",
    "batch_size = 64\n",
    "# Dimension of the noise vector fed to the generator.\n",
    "z_dim = 128\n",
    "# Learning rates of the generator and the critic optimizers, respectively.\n",
    "learning_rate_ger = 5e-5\n",
    "learning_rate_dis = 5e-5\n",
    "# Device string passed to tf.device() when the graph is built.\n",
    "device = '/gpu:0'\n",
    "# Image side length (images are s x s).\n",
    "s = 32\n",
    "# Number of critic updates per generator update. The training loop uses 100\n",
    "# critic updates instead when i < 25 or i % 500 == 0 (i is the iteration step).\n",
    "Citers = 5\n",
    "# Lower and upper bound the critic parameters are clipped to ('regular' mode only).\n",
    "clamp_lower = -0.01\n",
    "clamp_upper = 0.01\n",
    "# Whether to use the MLP networks (True) or the DCGAN-style conv networks (False).\n",
    "is_mlp = False\n",
    "# Whether to use Adam for the parameter updates; if False, tf.train.RMSPropOptimizer\n",
    "# is used, as recommended in the paper.\n",
    "is_adam = False\n",
    "# Whether to train on SVHN (True) or MNIST (False).\n",
    "is_svhn = False\n",
    "# Number of image channels: 3 for SVHN, 1 for MNIST.\n",
    "channel = 3 if is_svhn is True else 1\n",
    "# 'gp' for WGAN with gradient penalty, 'regular' for the weight-clipping WGAN.\n",
    "mode = 'gp'\n",
    "# Weight of the gradient penalty term in the critic loss (only used in 'gp' mode).\n",
    "lam = 10.\n",
    "# Image side length divided by 2, 4, 8 and 16.\n",
    "s2, s4, s8, s16 =\\\n",
    "    int(s / 2), int(s / 4), int(s / 8), int(s / 16)\n",
    "# Hidden layer sizes if the MLP networks are chosen; ignored otherwise.\n",
    "ngf = 64\n",
    "ndf = 64\n",
    "# Directory for TensorBoard logs (losses and gradient norms of generator and critic).\n",
    "log_dir = './log_wgan'\n",
    "# Directory for model checkpoints; created here if it does not exist yet.\n",
    "ckpt_dir = './ckpt_wgan'\n",
    "if not os.path.exists(ckpt_dir):\n",
    "    os.makedirs(ckpt_dir)\n",
    "# Maximum number of iteration steps. One step is a round of critic updates\n",
    "# (see Citers) followed by a single generator update.\n",
    "max_iter_step = 20000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator_conv(z):\n",
    "    \"\"\"DCGAN-style generator built from transposed convolutions.\n",
    "\n",
    "    The noise is projected by a batch-normalised fully connected layer to a\n",
    "    4x4x512 feature map, passed through three stride-2 transposed convolutions\n",
    "    (256, 128 and 64 filters) and finally mapped to `channel` feature maps by\n",
    "    a stride-1 transposed convolution with tanh activation.\n",
    "\n",
    "    Args:\n",
    "        z: noise tensor fed to the first fully connected layer.\n",
    "\n",
    "    Returns:\n",
    "        The output tensor of the final tanh layer.\n",
    "    \"\"\"\n",
    "    net = ly.fully_connected(\n",
    "        z, 4 * 4 * 512, activation_fn=lrelu, normalizer_fn=ly.batch_norm)\n",
    "    net = tf.reshape(net, (-1, 4, 4, 512))\n",
    "    # Layers are created in the same order as before so that the automatically\n",
    "    # generated variable scope names stay the same.\n",
    "    for depth in (256, 128, 64):\n",
    "        net = ly.conv2d_transpose(\n",
    "            net, depth, 3, stride=2,\n",
    "            activation_fn=tf.nn.relu, normalizer_fn=ly.batch_norm,\n",
    "            padding='SAME',\n",
    "            weights_initializer=tf.random_normal_initializer(0, 0.02))\n",
    "    net = ly.conv2d_transpose(\n",
    "        net, channel, 3, stride=1,\n",
    "        activation_fn=tf.nn.tanh, padding='SAME',\n",
    "        weights_initializer=tf.random_normal_initializer(0, 0.02))\n",
    "    print(net.name)\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator_mlp(z):\n",
    "    \"\"\"Fully connected generator.\n",
    "\n",
    "    Three batch-normalised leaky-ReLU layers (4*4*512, ngf and ngf units) are\n",
    "    followed by a batch-normalised tanh layer with s*s*channel units, whose\n",
    "    output is reshaped to (batch_size, s, s, channel).\n",
    "\n",
    "    Args:\n",
    "        z: noise tensor fed to the first fully connected layer.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (batch_size, s, s, channel).\n",
    "    \"\"\"\n",
    "    net = z\n",
    "    for width in (4 * 4 * 512, ngf, ngf):\n",
    "        net = ly.fully_connected(\n",
    "            net, width, activation_fn=lrelu, normalizer_fn=ly.batch_norm)\n",
    "    flat_img = ly.fully_connected(\n",
    "        net, s * s * channel,\n",
    "        activation_fn=tf.nn.tanh, normalizer_fn=ly.batch_norm)\n",
    "    return tf.reshape(flat_img, tf.stack([batch_size, s, s, channel]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def critic_conv(img, reuse=False):\n",
    "    \"\"\"Convolutional critic producing one unbounded score per sample.\n",
    "\n",
    "    Four stride-2 3x3 convolutions with leaky-ReLU activation (64, 128, 256\n",
    "    and 512 filters; all but the first are batch-normalised) are followed by\n",
    "    a linear fully connected layer with a single output.\n",
    "\n",
    "    Args:\n",
    "        img: batch of images; it is flattened to (batch_size, -1) before the\n",
    "            final layer, so it must contain exactly batch_size samples.\n",
    "        reuse: if true, reuse the variables already created in the 'critic'\n",
    "            variable scope instead of creating new ones.\n",
    "\n",
    "    Returns:\n",
    "        The critic output tensor (no activation applied).\n",
    "    \"\"\"\n",
    "    with tf.variable_scope('critic') as scope:\n",
    "        if reuse:\n",
    "            scope.reuse_variables()\n",
    "        base_depth = 64\n",
    "        # The first convolution is the only one without a normalizer.\n",
    "        net = ly.conv2d(img, num_outputs=base_depth, kernel_size=3,\n",
    "                        stride=2, activation_fn=lrelu)\n",
    "        for multiplier in (2, 4, 8):\n",
    "            net = ly.conv2d(net, num_outputs=base_depth * multiplier,\n",
    "                            kernel_size=3, stride=2, activation_fn=lrelu,\n",
    "                            normalizer_fn=ly.batch_norm)\n",
    "        flattened = tf.reshape(net, [batch_size, -1])\n",
    "        logit = ly.fully_connected(flattened, 1, activation_fn=None)\n",
    "    return logit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def critic_mlp(img, reuse=False):\n",
    "    \"\"\"Fully connected critic producing one unbounded score per sample.\n",
    "\n",
    "    The input is flattened to (batch_size, -1) and passed through three ReLU\n",
    "    layers of `ndf` units each, followed by a linear layer with one output.\n",
    "    The hidden width is taken from `ndf` (the critic-side size constant)\n",
    "    rather than from the generator's `ngf`.\n",
    "\n",
    "    Args:\n",
    "        img: batch of images; must contain exactly batch_size samples.\n",
    "        reuse: if true, reuse the variables already created in the 'critic'\n",
    "            variable scope instead of creating new ones.\n",
    "\n",
    "    Returns:\n",
    "        The critic output tensor (no activation applied).\n",
    "    \"\"\"\n",
    "    with tf.variable_scope('critic') as scope:\n",
    "        if reuse:\n",
    "            scope.reuse_variables()\n",
    "        hidden = tf.reshape(img, [batch_size, -1])\n",
    "        for _ in range(3):\n",
    "            hidden = ly.fully_connected(hidden, ndf, activation_fn=tf.nn.relu)\n",
    "        logit = ly.fully_connected(hidden, 1, activation_fn=None)\n",
    "    return logit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def build_graph():\n",
    "    \"\"\"Assemble the WGAN training graph in the current default graph.\n",
    "\n",
    "    Builds the generator and the critic (MLP or conv variants, depending on\n",
    "    `is_mlp`), the critic and generator losses, the summaries and one\n",
    "    optimizer per network. In 'gp' mode a gradient penalty weighted by `lam`\n",
    "    is added to the critic loss; in 'regular' mode the critic variables are\n",
    "    clipped to [clamp_lower, clamp_upper] after every critic update.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (opt_g, opt_c, real_data):\n",
    "        opt_g is the generator training op, opt_c is the critic training op\n",
    "        (in 'regular' mode: the list of clipped variable tensors, which also\n",
    "        triggers the critic update when run), and real_data is the\n",
    "        placeholder for a batch of real images.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `mode` is neither 'gp' nor 'regular'.\n",
    "    \"\"\"\n",
    "    # Validate first so an unsupported mode fails before any op is added to\n",
    "    # the graph.\n",
    "    if mode not in ('gp', 'regular'):\n",
    "        raise NotImplementedError('Only two modes')\n",
    "    noise_dist = tf.contrib.distributions.Normal(0., 1.)\n",
    "    z = noise_dist.sample((batch_size, z_dim))\n",
    "    generator = generator_mlp if is_mlp else generator_conv\n",
    "    critic = critic_mlp if is_mlp else critic_conv\n",
    "    with tf.variable_scope('generator'):\n",
    "        train = generator(z)\n",
    "    real_data = tf.placeholder(\n",
    "        dtype=tf.float32, shape=(batch_size, 32, 32, channel))\n",
    "    true_logit = critic(real_data)\n",
    "    fake_logit = critic(train, reuse=True)\n",
    "    c_loss = tf.reduce_mean(fake_logit - true_logit)\n",
    "    # Strings are compared with ==; an identity test only works when both\n",
    "    # strings happen to be the same interned object.\n",
    "    if mode == 'gp':\n",
    "        # Penalise the critic when the norm of its gradient at random points\n",
    "        # between real and generated samples deviates from 1.\n",
    "        alpha_dist = tf.contrib.distributions.Uniform(low=0., high=1.)\n",
    "        alpha = alpha_dist.sample((batch_size, 1, 1, 1))\n",
    "        interpolated = real_data + alpha * (train - real_data)\n",
    "        inte_logit = critic(interpolated, reuse=True)\n",
    "        gradients = tf.gradients(inte_logit, [interpolated])[0]\n",
    "        grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))\n",
    "        gradient_penalty = tf.reduce_mean((grad_l2 - 1) ** 2)\n",
    "        tf.summary.scalar(\"gp_loss\", gradient_penalty)\n",
    "        # NOTE: tf.nn.l2_loss is half the sum of squares, not the norm itself.\n",
    "        tf.summary.scalar(\"grad_norm\", tf.nn.l2_loss(gradients))\n",
    "        c_loss += lam * gradient_penalty\n",
    "    g_loss = tf.reduce_mean(-fake_logit)\n",
    "    # The summary ops register themselves in the graph; they are picked up\n",
    "    # later by tf.summary.merge_all(), so the return values are not needed.\n",
    "    tf.summary.scalar(\"g_loss\", g_loss)\n",
    "    tf.summary.scalar(\"c_loss\", c_loss)\n",
    "    tf.summary.image(\"img\", train, max_outputs=10)\n",
    "    theta_g = tf.get_collection(\n",
    "        tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')\n",
    "    theta_c = tf.get_collection(\n",
    "        tf.GraphKeys.TRAINABLE_VARIABLES, scope='critic')\n",
    "    # Both networks use the same optimizer factory.\n",
    "    if is_adam is True:\n",
    "        optimizer = partial(tf.train.AdamOptimizer, beta1=0.5, beta2=0.9)\n",
    "    else:\n",
    "        optimizer = tf.train.RMSPropOptimizer\n",
    "    counter_g = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)\n",
    "    opt_g = ly.optimize_loss(loss=g_loss, learning_rate=learning_rate_ger,\n",
    "                             optimizer=optimizer,\n",
    "                             variables=theta_g, global_step=counter_g,\n",
    "                             summaries=['gradient_norm'])\n",
    "    counter_c = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)\n",
    "    opt_c = ly.optimize_loss(loss=c_loss, learning_rate=learning_rate_dis,\n",
    "                             optimizer=optimizer,\n",
    "                             variables=theta_c, global_step=counter_c,\n",
    "                             summaries=['gradient_norm'])\n",
    "    if mode == 'regular':\n",
    "        # The clip ops must be created inside the control-dependency block,\n",
    "        # and must read the variable there, so that clipping is guaranteed to\n",
    "        # run after the optimizer update instead of racing with it.\n",
    "        with tf.control_dependencies([opt_c]):\n",
    "            clipped_var_c = [\n",
    "                tf.assign(var, tf.clip_by_value(\n",
    "                    var.read_value(), clamp_lower, clamp_upper))\n",
    "                for var in theta_c]\n",
    "            # merge the clip operations on critic variables\n",
    "            opt_c = tf.tuple(clipped_var_c)\n",
    "    return opt_g, opt_c, real_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def main():\n",
    "    \"\"\"Load the dataset, build the graph and run the WGAN training loop.\n",
    "\n",
    "    Every iteration performs `citers` critic updates (100 when i < 25 or\n",
    "    i % 500 == 0, otherwise `Citers`) followed by one generator update.\n",
    "    Every 100th iteration the merged summaries and a full execution trace are\n",
    "    written to `log_dir`; every 1000th iteration a checkpoint is saved to\n",
    "    `ckpt_dir`.\n",
    "    \"\"\"\n",
    "    if is_svhn is True:\n",
    "        dataset = load_svhn()\n",
    "    else:\n",
    "        dataset = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "    with tf.device(device):\n",
    "        opt_g, opt_c, real_data = build_graph()\n",
    "    merged_all = tf.summary.merge_all()\n",
    "    saver = tf.train.Saver()\n",
    "    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)\n",
    "    config.gpu_options.allow_growth = True\n",
    "    config.gpu_options.per_process_gpu_memory_fraction = 0.8\n",
    "\n",
    "    def next_feed_dict():\n",
    "        \"\"\"Fetch the next training batch and wrap it in a feed dict.\"\"\"\n",
    "        train_img = dataset.train.next_batch(batch_size)[0]\n",
    "        # presumably maps pixel values from [0, 1] to [-1, 1] - TODO confirm\n",
    "        # the value range returned by the dataset loaders\n",
    "        train_img = 2*train_img-1\n",
    "        if is_svhn is not True:\n",
    "            # MNIST images are 28x28: pad 2 pixels of -1 on every side to\n",
    "            # get 32x32, then add the channel axis.\n",
    "            train_img = np.reshape(train_img, (-1, 28, 28))\n",
    "            npad = ((0, 0), (2, 2), (2, 2))\n",
    "            train_img = np.pad(train_img, pad_width=npad,\n",
    "                               mode='constant', constant_values=-1)\n",
    "            train_img = np.expand_dims(train_img, -1)\n",
    "        return {real_data: train_img}\n",
    "\n",
    "    with tf.Session(config=config) as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)\n",
    "        try:\n",
    "            for i in range(max_iter_step):\n",
    "                if i < 25 or i % 500 == 0:\n",
    "                    citers = 100\n",
    "                else:\n",
    "                    citers = Citers\n",
    "                log_step = i % 100 == 99\n",
    "                if log_step:\n",
    "                    # Created here rather than inside the critic loop so the\n",
    "                    # generator step below never depends on that loop having\n",
    "                    # run.\n",
    "                    run_options = tf.RunOptions(\n",
    "                        trace_level=tf.RunOptions.FULL_TRACE)\n",
    "                for j in range(citers):\n",
    "                    feed_dict = next_feed_dict()\n",
    "                    if log_step and j == 0:\n",
    "                        run_metadata = tf.RunMetadata()\n",
    "                        _, merged = sess.run([opt_c, merged_all], feed_dict=feed_dict,\n",
    "                                             options=run_options, run_metadata=run_metadata)\n",
    "                        summary_writer.add_summary(merged, i)\n",
    "                        summary_writer.add_run_metadata(\n",
    "                            run_metadata, 'critic_metadata {}'.format(i), i)\n",
    "                    else:\n",
    "                        sess.run(opt_c, feed_dict=feed_dict)\n",
    "                feed_dict = next_feed_dict()\n",
    "                if log_step:\n",
    "                    run_metadata = tf.RunMetadata()\n",
    "                    _, merged = sess.run([opt_g, merged_all], feed_dict=feed_dict,\n",
    "                                         options=run_options, run_metadata=run_metadata)\n",
    "                    summary_writer.add_summary(merged, i)\n",
    "                    summary_writer.add_run_metadata(\n",
    "                        run_metadata, 'generator_metadata {}'.format(i), i)\n",
    "                else:\n",
    "                    sess.run(opt_g, feed_dict=feed_dict)\n",
    "                if i % 1000 == 999:\n",
    "                    saver.save(sess, os.path.join(\n",
    "                        ckpt_dir, \"model.ckpt\"), global_step=i)\n",
    "        finally:\n",
    "            # Flush pending events and release the event file even when\n",
    "            # training is interrupted.\n",
    "            summary_writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Start training: loads the dataset, builds the graph and runs the loop in main().\n",
    "main()"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
